{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\n    @name: King\\n    @title: Stochastic Training for Graph Convolutional Networks\\n    @url: https://github.com/dmlc/dgl/tree/master/examples/pytorch/sampling\\n    \\n'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'''\n",
    "    @name: King\n",
    "    @title: Stochastic Training for Graph Convolutional Networks\n",
    "    @url: https://github.com/dmlc/dgl/tree/master/examples/pytorch/sampling\n",
    "    \n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stochastic Training for Graph Convolutional Networks\n",
    "\n",
    "- Paper : [Control Variate](https://arxiv.org/abs/1710.10568)\n",
    "- Paper : [Skip Connection](https://arxiv.org/abs/1809.05343)\n",
    "- Author's code : [https://github.com/thu-ml/stochastic_gcn](https://github.com/thu-ml/stochastic_gcn)\n",
    "\n",
    "## Dependencies\n",
    "\n",
    "- PyTorch 0.4.1+\n",
    "- requests\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 一、包引入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse, time, math\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from functools import partial\n",
    "import dgl\n",
    "import dgl.function as fn\n",
    "from dgl import DGLGraph\n",
    "from dgl.data import register_data_args, load_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NodeUpdate(nn.Module):\n",
    "    def __init__(self, in_feats, out_feats, activation=None, test=False, concat=False):\n",
    "        super(NodeUpdate, self).__init__()\n",
    "        self.linear = nn.Linear(in_feats, out_feats)\n",
    "        self.activation = activation\n",
    "        self.concat = concat\n",
    "        self.test = test\n",
    "\n",
    "    def forward(self, node):\n",
    "        h = node.data['h']\n",
    "        if self.test:\n",
    "            h = h * node.data['norm']\n",
    "        h = self.linear(h)\n",
    "        # skip connection\n",
    "        if self.concat:\n",
    "            h = torch.cat((h, self.activation(h)), dim=1)\n",
    "        elif self.activation:\n",
    "            h = self.activation(h)\n",
    "        return {'activation': h}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GCNSampling(nn.Module):\n",
    "    def __init__(self,\n",
    "                 in_feats,\n",
    "                 n_hidden,\n",
    "                 n_classes,\n",
    "                 n_layers,\n",
    "                 activation,\n",
    "                 dropout):\n",
    "        super(GCNSampling, self).__init__()\n",
    "        self.n_layers = n_layers\n",
    "        if dropout != 0:\n",
    "            self.dropout = nn.Dropout(p=dropout)\n",
    "        else:\n",
    "            self.dropout = None\n",
    "        self.layers = nn.ModuleList()\n",
    "        # input layer\n",
    "        skip_start = (0 == n_layers-1)\n",
    "        self.layers.append(NodeUpdate(in_feats, n_hidden, activation, concat=skip_start))\n",
    "        # hidden layers\n",
    "        for i in range(1, n_layers):\n",
    "            skip_start = (i == n_layers-1)\n",
    "            self.layers.append(NodeUpdate(n_hidden, n_hidden, activation, concat=skip_start))\n",
    "        # output layer\n",
    "        self.layers.append(NodeUpdate(2*n_hidden, n_classes))\n",
    "\n",
    "    def forward(self, nf):\n",
    "        nf.layers[0].data['activation'] = nf.layers[0].data['features']\n",
    "\n",
    "        for i, layer in enumerate(self.layers):\n",
    "            h = nf.layers[i].data.pop('activation')\n",
    "            if self.dropout:\n",
    "                h = self.dropout(h)\n",
    "            nf.layers[i].data['h'] = h\n",
    "            nf.block_compute(i,\n",
    "                             fn.copy_src(src='h', out='m'),\n",
    "                             lambda node : {'h': node.mailbox['m'].mean(dim=1)},\n",
    "                             layer)\n",
    "\n",
    "        h = nf.layers[-1].data.pop('activation')\n",
    "        return h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GCNInfer(nn.Module):\n",
    "    def __init__(self,\n",
    "                 in_feats,\n",
    "                 n_hidden,\n",
    "                 n_classes,\n",
    "                 n_layers,\n",
    "                 activation):\n",
    "        super(GCNInfer, self).__init__()\n",
    "        self.n_layers = n_layers\n",
    "        self.layers = nn.ModuleList()\n",
    "        # input layer\n",
    "        skip_start = (0 == n_layers-1)\n",
    "        self.layers.append(NodeUpdate(in_feats, n_hidden, activation, test=True, concat=skip_start))\n",
    "        # hidden layers\n",
    "        for i in range(1, n_layers):\n",
    "            skip_start = (i == n_layers-1)\n",
    "            self.layers.append(NodeUpdate(n_hidden, n_hidden, activation, test=True, concat=skip_start))\n",
    "        # output layer\n",
    "        self.layers.append(NodeUpdate(2*n_hidden, n_classes, test=True))\n",
    "\n",
    "    def forward(self, nf):\n",
    "        nf.layers[0].data['activation'] = nf.layers[0].data['features']\n",
    "\n",
    "        for i, layer in enumerate(self.layers):\n",
    "            h = nf.layers[i].data.pop('activation')\n",
    "            nf.layers[i].data['h'] = h\n",
    "            nf.block_compute(i,\n",
    "                             fn.copy_src(src='h', out='m'),\n",
    "                             fn.sum(msg='m', out='h'),\n",
    "                             layer)\n",
    "\n",
    "        h = nf.layers[-1].data.pop('activation')\n",
    "        return h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(args):\n",
    "    # load and preprocess dataset\n",
    "    data = load_data(args)\n",
    "\n",
    "    if args.self_loop and not args.dataset.startswith('reddit'):\n",
    "        data.graph.add_edges_from([(i,i) for i in range(len(data.graph))])\n",
    "\n",
    "    train_nid = np.nonzero(data.train_mask)[0].astype(np.int64)\n",
    "    test_nid = np.nonzero(data.test_mask)[0].astype(np.int64)\n",
    "\n",
    "    features = torch.FloatTensor(data.features)\n",
    "    labels = torch.LongTensor(data.labels)\n",
    "    train_mask = torch.ByteTensor(data.train_mask)\n",
    "    val_mask = torch.ByteTensor(data.val_mask)\n",
    "    test_mask = torch.ByteTensor(data.test_mask)\n",
    "    in_feats = features.shape[1]\n",
    "    n_classes = data.num_labels\n",
    "    n_edges = data.graph.number_of_edges()\n",
    "\n",
    "    n_train_samples = train_mask.sum().item()\n",
    "    n_val_samples = val_mask.sum().item()\n",
    "    n_test_samples = test_mask.sum().item()\n",
    "\n",
    "    print(\"\"\"----Data statistics------'\n",
    "      #Edges %d\n",
    "      #Classes %d\n",
    "      #Train samples %d\n",
    "      #Val samples %d\n",
    "      #Test samples %d\"\"\" %\n",
    "          (n_edges, n_classes,\n",
    "              n_train_samples,\n",
    "              n_val_samples,\n",
    "              n_test_samples))\n",
    "\n",
    "    # create GCN model\n",
    "    g = DGLGraph(data.graph, readonly=True)\n",
    "    norm = 1. / g.in_degrees().float().unsqueeze(1)\n",
    "\n",
    "    if args.gpu < 0:\n",
    "        cuda = False\n",
    "    else:\n",
    "        cuda = True\n",
    "        torch.cuda.set_device(args.gpu)\n",
    "        features = features.cuda()\n",
    "        labels = labels.cuda()\n",
    "        train_mask = train_mask.cuda()\n",
    "        val_mask = val_mask.cuda()\n",
    "        test_mask = test_mask.cuda()\n",
    "        norm = norm.cuda()\n",
    "\n",
    "    g.ndata['features'] = features\n",
    "\n",
    "    num_neighbors = args.num_neighbors\n",
    "\n",
    "    g.ndata['norm'] = norm\n",
    "\n",
    "    model = GCNSampling(in_feats,\n",
    "                        args.n_hidden,\n",
    "                        n_classes,\n",
    "                        args.n_layers,\n",
    "                        F.relu,\n",
    "                        args.dropout)\n",
    "\n",
    "    if cuda:\n",
    "        model.cuda()\n",
    "\n",
    "    loss_fcn = nn.CrossEntropyLoss()\n",
    "\n",
    "    infer_model = GCNInfer(in_feats,\n",
    "                           args.n_hidden,\n",
    "                           n_classes,\n",
    "                           args.n_layers,\n",
    "                           F.relu)\n",
    "\n",
    "    if cuda:\n",
    "        infer_model.cuda()\n",
    "\n",
    "    # use optimizer\n",
    "    optimizer = torch.optim.Adam(model.parameters(),\n",
    "                                 lr=args.lr,\n",
    "                                 weight_decay=args.weight_decay)\n",
    "\n",
    "    # initialize graph\n",
    "    dur = []\n",
    "    for epoch in range(args.n_epochs):\n",
    "        for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,\n",
    "                                                       args.num_neighbors,\n",
    "                                                       neighbor_type='in',\n",
    "                                                       shuffle=True,\n",
    "                                                       num_workers=32,\n",
    "                                                       num_hops=args.n_layers+1,\n",
    "                                                       seed_nodes=train_nid):\n",
    "            nf.copy_from_parent()\n",
    "            model.train()\n",
    "            # forward\n",
    "            pred = model(nf)\n",
    "            batch_nids = nf.layer_parent_nid(-1).to(device=pred.device, dtype=torch.long)\n",
    "            batch_labels = labels[batch_nids]\n",
    "            loss = loss_fcn(pred, batch_labels)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        for infer_param, param in zip(infer_model.parameters(), model.parameters()):\n",
    "            infer_param.data.copy_(param.data)\n",
    "\n",
    "        num_acc = 0.\n",
    "\n",
    "        for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,\n",
    "                                                       g.number_of_nodes(),\n",
    "                                                       neighbor_type='in',\n",
    "                                                       num_workers=32,\n",
    "                                                       num_hops=args.n_layers+1,\n",
    "                                                       seed_nodes=test_nid):\n",
    "            nf.copy_from_parent()\n",
    "            infer_model.eval()\n",
    "            with torch.no_grad():\n",
    "                pred = infer_model(nf)\n",
    "                batch_nids = nf.layer_parent_nid(-1).to(device=pred.device, dtype=torch.long)\n",
    "                batch_labels = labels[batch_nids]\n",
    "                num_acc += (pred.argmax(dim=1) == batch_labels).sum().cpu().item()\n",
    "\n",
    "        print(\"Test Accuracy {:.4f}\". format(num_acc/n_test_samples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ArgsClass():\n",
    "    def __init__(self):\n",
    "        self.dropout = 0.5\n",
    "        self.gpu = -10\n",
    "        self.lr = 3e-2\n",
    "        self.n_epochs = 200\n",
    "        self.batch_size = 1000\n",
    "        self.test_batch_size = 1000\n",
    "        self.num_neighbors = 3\n",
    "        self.n_hidden = 16\n",
    "        self.n_layers = 1\n",
    "        self.self_loop = 'store_true'\n",
    "        self.weight_decay = 5e-4\n",
    "        self.dataset = 'cora'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<__main__.ArgsClass object at 0x000002155FFEB668>\n"
     ]
    }
   ],
   "source": [
    "args = ArgsClass()\n",
    "print(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----Data statistics------'\n",
      "      #Edges 13264\n",
      "      #Classes 7\n",
      "      #Train samples 140\n",
      "      #Val samples 300\n",
      "      #Test samples 1000\n"
     ]
    },
    {
     "ename": "AttributeError",
     "evalue": "module 'torch._C' has no attribute '_cuda_setDevice'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-19-b3f47c9a9b7f>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mmain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32m<ipython-input-10-97bb0f4ec5af>\u001b[0m in \u001b[0;36mmain\u001b[1;34m(args)\u001b[0m\n\u001b[0;32m     41\u001b[0m     \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     42\u001b[0m         \u001b[0mcuda\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 43\u001b[1;33m         \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mset_device\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgpu\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     44\u001b[0m         \u001b[0mfeatures\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mfeatures\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     45\u001b[0m         \u001b[0mlabels\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mlabels\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\progrom\\python\\python\\python3\\lib\\site-packages\\torch\\cuda\\__init__.py\u001b[0m in \u001b[0;36mset_device\u001b[1;34m(device)\u001b[0m\n\u001b[0;32m    260\u001b[0m     \"\"\"\n\u001b[0;32m    261\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mdevice\u001b[0m \u001b[1;33m>=\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 262\u001b[1;33m         \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_cuda_setDevice\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    263\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    264\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mAttributeError\u001b[0m: module 'torch._C' has no attribute '_cuda_setDevice'"
     ]
    }
   ],
   "source": [
    "main(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
